…il in ModelTrain (5504)
aviruthen
left a comment
🤖 AI Code Review
This PR fixes PipelineVariable support in ModelTrainer but has several issues: removed type annotations without replacement, a magic string constant, and the test file referenced in the description is not included in the diff. The approach is reasonable but needs refinement.
| def _validate_training_image_and_algorithm_name( | ||
| self, training_image: Optional[str], algorithm_name: Optional[str] | ||
| self, training_image, algorithm_name |
Missing type annotations: The type annotations for training_image and algorithm_name were removed entirely. Per SDK coding standards (PEP 484), all public/private methods must retain type annotations. Since these parameters now accept both str and PipelineVariable, please use the appropriate union type:
    def _validate_training_image_and_algorithm_name(
        self,
        training_image: str | PipelineVariable | None,
        algorithm_name: str | PipelineVariable | None,
    ):
Or, if StrPipeVar is already defined as a type alias in the codebase, use that.
| Args: | ||
| image (str): The image URI | ||
| image: The image URI (str or PipelineVariable) |
Removed type annotation: Same issue here — the image parameter lost its type annotation. Please restore it with the correct union type:
    def _get_repo_name_from_image(image: str | PipelineVariable) -> str:
| def _validate_training_image_and_algorithm_name( | ||
| self, training_image: Optional[str], algorithm_name: Optional[str] | ||
| self, training_image, algorithm_name |
Missing test file in the diff: The PR description references test_model_trainer_pipeline_variable.py but this file is not included in the changed files. Please ensure the test file is included in the PR. Without tests, we cannot verify the fix works or guard against regressions.
aviruthen
left a comment
🤖 AI Code Review
This PR fixes PipelineVariable support in ModelTrainer by handling non-string PipelineVariable objects in validation and utility functions. The approach is reasonable, but there are some issues: a line exceeds 100 characters, test fixtures have significant duplication that should be extracted, and the from __future__ import annotations import is missing in favor of the older from __future__ import absolute_import.
| mock_session = MagicMock() | ||
| mock_session.boto_region_name = "us-east-1" | ||
| mock_session.default_bucket.return_value = "my-bucket" | ||
| mock_session.default_bucket_prefix = None |
The mock session and mock_defaults setup is duplicated across 4 test methods (test_training_image_accepts_parameter_string, test_algorithm_name_accepts_parameter_string, test_environment_values_accept_parameter_string, test_plain_string_values_still_work). Extract this into a @pytest.fixture to reduce duplication and improve maintainability:
    @pytest.fixture
    def mock_session():
        session = MagicMock()
        session.boto_region_name = "us-east-1"
        session.default_bucket.return_value = "my-bucket"
        session.default_bucket_prefix = None
        return session
Do the same for the mock_defaults patching.
| mock_defaults.get_role.return_value = "arn:aws:iam::123456789012:role/SageMakerRole" | ||
| mock_defaults.get_base_job_name.return_value = "test-job" | ||
| mock_defaults.get_compute.return_value = Compute( | ||
| instance_type="ml.m5.xlarge", instance_count=1 |
Consider adding a test case for _validate_training_image_and_algorithm_name where one argument is a PipelineVariable and the other is None — this is the primary success case the fix enables. The current tests test_training_image_accepts_parameter_string and test_algorithm_name_accepts_parameter_string test this indirectly through full ModelTrainer construction, but a direct unit test of the validation method (like the rejection tests at lines 199-238) would be more focused and faster.
aviruthen
left a comment
🤖 AI Code Review
This PR fixes PipelineVariable support in ModelTrainer by handling PipelineVariable objects that don't support standard boolean coercion. The approach is reasonable, but there are a few issues: a duplicate import in utils.py, the validation logic could be simplified, and the test file has imports inside test methods rather than at module level.
| from datetime import datetime | ||
| from typing import Literal, Any | ||
| from typing import Union |
Duplicate import: Union is imported here from typing, but there's already a from typing import Literal, Any on line 24. Consolidate into a single import statement:
    from typing import Literal, Any, Union
Also, since the module already imports PipelineVariable from sagemaker.core.workflow.parameters on line 30, and from __future__ import annotations is not present, consider adding it to enable PEP 604 union syntax (str | PipelineVariable) per SDK conventions.
| # PipelineVariable objects do not support standard boolean coercion | ||
| # (__bool__ raises TypeError), so we use isinstance checks to detect | ||
| # them as truthy values during validation. | ||
| has_image = isinstance(training_image, PipelineVariable) or bool(training_image) |
Calling bool(training_image) on its own would raise TypeError for a PipelineVariable that doesn't support __bool__. In the expression isinstance(training_image, PipelineVariable) or bool(training_image), the isinstance check runs first and or short-circuits, so bool() is never reached for a PipelineVariable and the logic is safe. However, consider simplifying to:
    has_image = training_image is not None and training_image != ""
    has_algo = algorithm_name is not None and algorithm_name != ""
This avoids calling bool() entirely and is more explicit about what "not provided" means (None or empty string). The is not None check naturally handles PipelineVariable objects correctly.
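The two forms can be compared with a small sketch. The `PipelineVariable` class below is a hypothetical stand-in that rejects boolean coercion, as the code comment in the diff describes; it does not override `__eq__`, so whether `!= ""` behaves the same way on the real class is an assumption worth checking.

```python
class PipelineVariable:
    """Hypothetical stand-in; assumed to reject bool() as described in the diff."""

    def __bool__(self):
        raise TypeError("PipelineVariable does not support boolean coercion")


def has_value_isinstance(value):
    # isinstance runs first; for a PipelineVariable the `or` short-circuits,
    # so bool() is never called on it.
    return isinstance(value, PipelineVariable) or bool(value)


def has_value_explicit(value):
    # Suggested alternative: only None and "" count as "not provided".
    return value is not None and value != ""
```

The two helpers agree on `None`, `""`, ordinary strings, and the stand-in `PipelineVariable`. They differ on other falsy values: `has_value_isinstance(0)` is falsy, while `has_value_explicit(0)` is truthy, so the simplification slightly changes what counts as "provided" for non-string input.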
| def _get_repo_name_from_image(image: str) -> str: | ||
| def _get_repo_name_from_image(image: Union[str, PipelineVariable]) -> str: |
The return type annotation says -> str but when a PipelineVariable is passed, it returns a string placeholder, so the annotation is technically correct. However, consider documenting in the docstring that the placeholder _PIPELINE_VARIABLE_IMAGE_PLACEHOLDER is returned for PipelineVariable inputs, so downstream callers understand the behavior.
| from sagemaker.train.model_trainer import ModelTrainer | ||
| from sagemaker.train.configs import Compute | ||
| param = ParameterString( |
The from sagemaker.train.model_trainer import ModelTrainer import is repeated inside every test method in this class. Move it to the top of the file with the other imports. Inline imports in tests add unnecessary noise and are not consistent with SDK test conventions.
| _TEST_IMAGE_URI = ( | ||
| "683313688378.dkr.ecr.us-east-1.amazonaws.com/" | ||
| "sagemaker-xgboost:1.0-1-cpu-py3" | ||
| ) |
Nit: This test image URI contains a hardcoded region (us-east-1) and account ID (683313688378). While this is acceptable for unit tests since it's just a string constant and not used to make actual API calls, consider using a clearly fake account ID (e.g., 123456789012) for consistency with the mock session fixture below.
| assert trainer.training_image == _TEST_IMAGE_URI | ||
| def test_validation_accepts_pipeline_variable_image_none_algo(self): | ||
| """Test validation accepts PipelineVariable image with None algorithm.""" |
Using ModelTrainer.__new__(ModelTrainer) to bypass __init__ and directly test the validation method is fragile — it creates an uninitialized object. If _validate_training_image_and_algorithm_name ever accesses self attributes, these tests will break with confusing errors. Consider either:
- Making _validate_training_image_and_algorithm_name a @staticmethod (it doesn't use self), or
- Using the existing mock_train_defaults fixture to construct a proper instance and test through the public interface.
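The first option could be sketched as follows. `ModelTrainerSketch` and the stand-in `PipelineVariable` are hypothetical, and the rule that exactly one of the two arguments must be provided is an assumption about the real validation, not something shown in the diff.

```python
class PipelineVariable:
    """Hypothetical stand-in; assumed to reject bool() as described in the diff."""

    def __bool__(self):
        raise TypeError("PipelineVariable does not support boolean coercion")


class ModelTrainerSketch:
    """Hypothetical class standing in for ModelTrainer."""

    @staticmethod
    def _validate_training_image_and_algorithm_name(training_image, algorithm_name):
        has_image = training_image is not None and training_image != ""
        has_algo = algorithm_name is not None and algorithm_name != ""
        # Assumed rule: exactly one of the two must be provided.
        if has_image == has_algo:
            raise ValueError(
                "Exactly one of training_image or algorithm_name must be provided."
            )


# No instance (and no __new__ trick) is needed to exercise the validation.
ModelTrainerSketch._validate_training_image_and_algorithm_name(PipelineVariable(), None)
```

With a static method, the direct unit tests call the validator on the class itself, so they cannot break later because an uninitialized instance is missing attributes.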
| class TestSafeSerializeWithPipelineVariable: | ||
| """Tests for safe_serialize handling of PipelineVariable objects.""" | ||
The TestSafeSerializeWithPipelineVariable tests verify safe_serialize behavior with PipelineVariable, but the PR diff doesn't show any changes to safe_serialize. If safe_serialize already handled PipelineVariable correctly, these tests are documenting existing behavior (which is fine), but it would be good to note that in the test class docstring. If safe_serialize needed changes, those changes should be included in this PR.
Description
PipelineVariable Support in ModelTrainer Fields (GH#5524)
This PR ensures that ModelTrainer fields that accept StrPipeVar (Union of str and PipelineVariable) work correctly when PipelineVariable objects (e.g., ParameterString) are passed.
Changes
- sagemaker-train/src/sagemaker/train/utils.py: Updated _get_repo_name_from_image to handle PipelineVariable objects gracefully by returning a default name instead of attempting string operations on non-string types.
- sagemaker-train/src/sagemaker/train/model_trainer.py: Updated _validate_training_image_and_algorithm_name to properly detect PipelineVariable instances as truthy values during validation, since PipelineVariable objects may not support standard boolean coercion.
Testing
Verified with unit tests in test_model_trainer_pipeline_variable.py that:
- training_image, algorithm_name, training_input_mode accept ParameterString
- environment dict values accept ParameterString
- Invalid types (e.g., int) are still rejected
Related Issue
Related issue: 5504
Changes Made
No response from agent
AI-Generated PR
This PR was automatically generated by the PySDK Issue Agent.
Merge Checklist
prefix: description format